import copy
import math
import random
import unittest

import torch
import torch.nn.functional as F

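# The fused conv kernels live in the optional apex contrib extension; skip the
# whole suite when the extension is not built/installed.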
HAS_CONV_BIAS_RELU = None
try:
    from apex.contrib.conv_bias_relu import (
        ConvBiasReLU,
        ConvBias,
        ConvBiasMaskReLU,
        ConvFrozenScaleBiasReLU,
    )
except ImportError:
    HAS_CONV_BIAS_RELU = False
else:
    HAS_CONV_BIAS_RELU = True


@unittest.skipIf(not HAS_CONV_BIAS_RELU, "`apex.contrib.conv_bias_relu` is not found.")
class ConvBiasReLUTest(unittest.TestCase):
    def setUp(self, seed=0):
        super().setUp()
        torch.manual_seed(seed)

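        # Randomized problem sizes; channel counts stay multiples of 8
        # (presumably for tensor-core/cuDNN-friendly alignment in the fused kernels).
        # Note: torch.manual_seed does not seed Python's `random`, so sizes vary per run.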
        self.batch_size = random.randint(1, 64)
        self.in_channels = random.randint(1, 64) * 8
        self.out_channels = random.randint(1, 64) * 8
        self.in_height = self.in_width = random.randint(5, 100)
        self.conv_kernel_size = random.randint(1, 5)
        self.conv_pad = random.randint(0, self.conv_kernel_size // 2)
        self.conv_stride = random.randint(1, 5)
        self.conv_dilation = 1
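        # Standard conv output-size formula:
        # out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride + 1)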
        self.out_height = self.out_width = math.floor(
            (
                self.in_height
                + 2 * self.conv_pad
                - self.conv_dilation * (self.conv_kernel_size - 1)
                - 1
            )
            / self.conv_stride
            + 1
        )

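        # Two identical inputs: self.x feeds the fused op, self.x_ feeds the
        # PyTorch reference, so outputs and gradients can be compared afterwards.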
        self.x = (
            torch.randint(
                low=-16,
                high=16,
                size=[self.batch_size, self.in_channels, self.in_height, self.in_width],
            )
            .cuda()
            .to(memory_format=torch.channels_last)
            .float()
        )
        self.x_ = self.x.clone()
        self.x.requires_grad_()
        self.x_.requires_grad_()

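        # Binary int8 mask used by the ConvBiasMaskReLU test.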
        self.mask = (
            torch.randn([self.batch_size, self.out_channels, self.out_height, self.out_width])
            .cuda()
            .to(memory_format=torch.channels_last)
        )
        self.mask = (self.mask > 0).to(torch.int8)
        self.mask_ = self.mask.clone()

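        # Frozen (non-learnable) per-channel scale and bias for ConvFrozenScaleBiasReLU.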
        self.scale = torch.randn([1, self.out_channels, 1, 1]).half().cuda()
        self.scale_ = self.scale.clone()
        self.bias = torch.randn([1, self.out_channels, 1, 1]).half().cuda()
        self.bias_ = self.bias.clone()

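        # conv1 (with bias) backs the ConvBias / ConvBiasReLU / ConvBiasMaskReLU tests;
        # each _-suffixed deep copy serves as the PyTorch reference module.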
        self.conv1 = (
            torch.nn.Conv2d(
                self.in_channels,
                self.out_channels,
                self.conv_kernel_size,
                stride=self.conv_stride,
                padding=self.conv_pad,
            )
            .cuda()
            .to(memory_format=torch.channels_last)
        )
        self.conv1_ = copy.deepcopy(self.conv1)

        self.conv2 = (
            torch.nn.Conv2d(
                self.in_channels,
                self.out_channels,
                self.conv_kernel_size,
                stride=self.conv_stride,
                padding=self.conv_pad,
                bias=False,
            )
            .cuda()
            .to(memory_format=torch.channels_last)
        )
        self.conv2_ = copy.deepcopy(self.conv2)

        print()
        print(
            "> input=[{}, {}, {}, {}]".format(
                self.batch_size, self.in_channels, self.in_height, self.in_width
            )
        )
        print(
            "> kernel=[{}, {}, {}, {}], stride={}, pad={}".format(
                self.out_channels,
                self.in_channels,
                self.conv_kernel_size,
                self.conv_kernel_size,
                self.conv_stride,
                self.conv_pad,
            )
        )

    def test_conv_bias_relu(self):
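        # Fused conv + bias + ReLU vs. F.relu(conv1(x)) under fp16 autocast.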
        with torch.amp.autocast("cuda", dtype=torch.half):
            out = ConvBiasReLU(
                self.x,
                self.conv1.weight,
                self.conv1.bias.reshape(1, -1, 1, 1),
                self.conv_pad,
                self.conv_stride,
            )
            loss = (out.float() ** 2).sum() / out.numel()
        loss.backward()
        with torch.amp.autocast("cuda", dtype=torch.half):
            out_ = F.relu(self.conv1_(self.x_))
            loss_ = (out_**2).sum() / out_.numel()
        loss_.backward()

        torch.testing.assert_close(out, out_, atol=1e-3, rtol=1e-3, equal_nan=True)
        torch.testing.assert_close(
            self.conv1_.bias.grad,
            self.conv1.bias.grad,
            atol=1e-3,
            rtol=1e-3,
            equal_nan=True,
        )
        torch.testing.assert_close(
            self.conv1_.weight.grad,
            self.conv1.weight.grad,
            atol=1e-3,
            rtol=1e-3,
            equal_nan=True,
        )
        torch.testing.assert_close(self.x_.grad, self.x.grad, atol=1e-3, rtol=1e-3, equal_nan=True)

    def test_conv_bias(self):
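        # Fused conv + bias vs. the plain conv1(x) reference.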
        with torch.amp.autocast("cuda", dtype=torch.half):
            out = ConvBias(
                self.x,
                self.conv1.weight,
                self.conv1.bias.reshape(1, -1, 1, 1),
                self.conv_pad,
                self.conv_stride,
            )
            loss = (out.float() ** 2).sum() / out.numel()
        loss.backward()

        with torch.amp.autocast("cuda", dtype=torch.half):
            out_ = self.conv1_(self.x_)
            loss_ = (out_**2).sum() / out_.numel()
        loss_.backward()

        torch.testing.assert_close(out, out_, atol=1e-3, rtol=1e-3, equal_nan=True)
        torch.testing.assert_close(
            self.conv1_.bias.grad,
            self.conv1.bias.grad,
            atol=1e-3,
            rtol=1e-3,
            equal_nan=True,
        )
        torch.testing.assert_close(
            self.conv1_.weight.grad,
            self.conv1.weight.grad,
            atol=1e-3,
            rtol=1e-3,
            equal_nan=True,
        )
        torch.testing.assert_close(self.x_.grad, self.x.grad, atol=1e-3, rtol=1e-3, equal_nan=True)

    def test_conv_bias_mask_relu(self):
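        # Fused conv + bias + mask + ReLU vs. F.relu(conv1(x) * mask).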
        with torch.amp.autocast("cuda", dtype=torch.half):
            out = ConvBiasMaskReLU(
                self.x,
                self.conv1.weight,
                self.conv1.bias.reshape(1, -1, 1, 1),
                self.mask,
                self.conv_pad,
                self.conv_stride,
            )
            loss = (out.float() ** 2).sum() / out.numel()
        loss.backward()
        with torch.amp.autocast("cuda", dtype=torch.half):
            out_ = F.relu(self.conv1_(self.x_) * self.mask_)
            loss_ = (out_**2).sum() / out_.numel()
        loss_.backward()

        torch.testing.assert_close(out, out_, atol=1e-3, rtol=1e-3, equal_nan=True)
        torch.testing.assert_close(
            self.conv1_.bias.grad,
            self.conv1.bias.grad,
            atol=1e-3,
            rtol=1e-3,
            equal_nan=True,
        )
        torch.testing.assert_close(
            self.conv1_.weight.grad,
            self.conv1.weight.grad,
            atol=1e-3,
            rtol=1e-3,
            equal_nan=True,
        )
        torch.testing.assert_close(self.x_.grad, self.x.grad, atol=1e-3, rtol=1e-3, equal_nan=True)

    def test_conv_frozen_scale_bias_relu(self):
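        # Fused conv + frozen scale/bias + ReLU vs. F.relu(conv2(x) * scale + bias).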
        with torch.amp.autocast("cuda", dtype=torch.half):
            out = ConvFrozenScaleBiasReLU(
                self.x,
                self.conv2.weight,
                self.scale,
                self.bias,
                self.conv_pad,
                self.conv_stride,
            )
            loss = (out.float() ** 2).sum() / out.numel()
        loss.backward()
        with torch.amp.autocast("cuda", dtype=torch.half):
            out_ = F.relu(self.conv2_(self.x_) * self.scale_ + self.bias_)
            loss_ = (out_**2).sum() / out_.numel()
        loss_.backward()

        torch.testing.assert_close(out, out_, atol=2.5e-3, rtol=2.5e-3, equal_nan=True)
        torch.testing.assert_close(
            self.conv2_.weight.grad,
            self.conv2.weight.grad,
            atol=1e-3,
            rtol=1e-3,
            equal_nan=True,
        )
        torch.testing.assert_close(self.x_.grad, self.x.grad, atol=1e-3, rtol=1e-3, equal_nan=True)


if __name__ == "__main__":
    unittest.main()
